{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Scalable Additive-Structure GP Classification (CUDA) (w/ KISS-GP)\n",
    "\n",
    "## Introduction\n",
    "\n",
    "This example shows how to use a `AdditiveGridInducingVariationalGP` module. This classifcation module is designed for when the function you’re modeling has an additive decomposition over dimension. This is equivalent to using a covariance function that additively decomposes over dimensions:\n",
    "\n",
    "$$k(\\mathbf{x},\\mathbf{x'}) = \\sum_{i=1}^{d}k([\\mathbf{x}]_{i}, [\\mathbf{x'}]_{i})$$\n",
    "\n",
    "where $[\\mathbf{x}]_{i}$ denotes the ith component of the vector $\\mathbf{x}$. Example applications of this include use in Bayesian optimization, and when performing deep kernel learning. \n",
    "\n",
    "The use of inducing points allows for scaling up the training data by making computational complexity linear instead of cubic in the number of data points.\n",
    "\n",
    "\n",
    "In this example, we’re performing classification on a two dimensional toy dataset that is:\n",
    "- Defined in [-1, 1]x[-1, 1]\n",
    "- Valued 1 in [-0.5, 0.5]x[-0.5, 0.5]\n",
    "- Valued -1 otherwise\n",
    "\n",
    "The above function doesn't have an obvious additive decomposition, but it turns out that this function is can be very well approximated by the kernel anyways."
   ]
  },
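  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the additive kernel above (not GPyTorch's implementation), the snippet below sums a one-dimensional RBF kernel over the input dimensions. The helper names `rbf_1d` and `additive_rbf` and the lengthscale value are assumptions made for illustration.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "def rbf_1d(a, b, lengthscale=0.5):\n",
    "    # a has shape (n,), b has shape (m,); returns the (n, m) RBF kernel matrix\n",
    "    diff = a.unsqueeze(1) - b.unsqueeze(0)\n",
    "    return torch.exp(-0.5 * (diff / lengthscale) ** 2)\n",
    "\n",
    "\n",
    "def additive_rbf(x1, x2, lengthscale=0.5):\n",
    "    # Sum the same 1D kernel over every input dimension\n",
    "    num_dims = x1.size(1)\n",
    "    return sum(rbf_1d(x1[:, i], x2[:, i], lengthscale) for i in range(num_dims))\n",
    "\n",
    "\n",
    "x = torch.tensor([[0.0, 0.0], [0.5, -0.5], [1.0, 1.0]])\n",
    "K = additive_rbf(x, x)\n",
    "print(K)\n",
    "```\n",
    "\n",
    "Each diagonal entry equals the number of dimensions, since every one-dimensional RBF term contributes $k(x, x) = 1$."
   ]
  },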
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# High-level imports\n",
    "import math\n",
    "from math import exp\n",
    "import torch\n",
    "import gpytorch\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Make inline plots\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate toy dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 51\n",
    "train_x = torch.zeros(n ** 2, 2)\n",
    "train_x[:, 0].copy_(torch.linspace(-1, 1, n).repeat(n))\n",
    "train_x[:, 1].copy_(torch.linspace(-1, 1, n).unsqueeze(1).repeat(1, n).view(-1))\n",
    "train_y = (train_x[:, 0].abs().lt(0.5)).float() * (train_x[:, 1].abs().lt(0.5)).float()\n",
    "\n",
    "train_x = train_x.cuda()\n",
    "train_y = train_y.cuda()"
   ]
  },
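  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The labels built by the code above take the values 0 and 1. A CPU-only sketch, assuming the same grid construction, can be used to check this and to count the positive points. The names `grid_1d`, `x` and `y` are introduced here for illustration so that the real training tensors are left untouched.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "n = 51\n",
    "grid_1d = torch.linspace(-1, 1, n)\n",
    "\n",
    "# Same layout as train_x: the first coordinate varies fastest\n",
    "x = torch.zeros(n ** 2, 2)\n",
    "x[:, 0] = grid_1d.repeat(n)\n",
    "x[:, 1] = grid_1d.unsqueeze(1).repeat(1, n).view(-1)\n",
    "\n",
    "# 1 inside the inner square, 0 elsewhere\n",
    "y = x[:, 0].abs().lt(0.5).float() * x[:, 1].abs().lt(0.5).float()\n",
    "\n",
    "# Distinct label values, then the number of positive points\n",
    "print(sorted(set(y.tolist())))\n",
    "print(int(y.sum().item()), 'positive points out of', y.numel())\n",
    "```"
   ]
  },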
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the model\n",
    "\n",
    "In contrast to the most basic classification models, this model uses an `AdditiveGridInterpolationVariationalStrategy`. This causes two key changes in the model. First, the model now specifically assumes that the input to `forward`, `x`, is to be additive decomposed. Thus, although the model below defines an `RBFKernel` as the covariance function, because we extend this base class, the additive decomposition discussed above will be imposed. \n",
    "\n",
    "Second, this model automatically assumes we will be using scalable kernel interpolation (SKI) for each dimension. Because of the additive decomposition, we only provide one set of grid bounds to the base class constructor, as the same grid will be used for all dimensions. It is recommended that you scale your training and test data appropriately."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gpytorch.models import AbstractVariationalGP\n",
    "from gpytorch.variational import AdditiveGridInterpolationVariationalStrategy, CholeskyVariationalDistribution\n",
    "from gpytorch.kernels import RBFKernel, ScaleKernel\n",
    "from gpytorch.likelihoods import BernoulliLikelihood\n",
    "from gpytorch.means import ConstantMean\n",
    "from gpytorch.distributions import MultivariateNormal\n",
    "\n",
    "class GPClassificationModel(AbstractVariationalGP):\n",
    "    def __init__(self, grid_size=64, grid_bounds=([-1, 1],)):\n",
    "        variational_distribution = CholeskyVariationalDistribution(num_inducing_points=grid_size, batch_size=2)\n",
    "        variational_strategy = AdditiveGridInterpolationVariationalStrategy(self,\n",
    "                                                                            grid_size=grid_size,\n",
    "                                                                            grid_bounds=grid_bounds,\n",
    "                                                                            num_dim=2,\n",
    "                                                                            variational_distribution=variational_distribution)\n",
    "        super(GPClassificationModel, self).__init__(variational_strategy)\n",
    "        self.mean_module = ConstantMean()\n",
    "        self.covar_module = ScaleKernel(RBFKernel(ard_num_dims=1))\n",
    "\n",
    "    def forward(self, x):\n",
    "        mean_x = self.mean_module(x)\n",
    "        covar_x = self.covar_module(x)\n",
    "        latent_pred = MultivariateNormal(mean_x, covar_x)\n",
    "        return latent_pred\n",
    "\n",
    "# Cuda the model and likelihood function\n",
    "model = GPClassificationModel().cuda()\n",
    "likelihood = gpytorch.likelihoods.BernoulliLikelihood().cuda()"
   ]
  },
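  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To give a feel for what kernel interpolation does in each dimension, here is a minimal 1D sketch that approximates an exact RBF kernel matrix by $W K_{UU} W^\\top$, where $K_{UU}$ is the kernel evaluated on a regular grid and $W$ holds interpolation weights. It uses linear interpolation purely for brevity, which is an assumption of this sketch and may differ from the interpolation scheme GPyTorch uses. The helper names and the lengthscale are hypothetical.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "def rbf(a, b, lengthscale=0.3):\n",
    "    # (n,) and (m,) inputs give an (n, m) RBF kernel matrix\n",
    "    diff = a.unsqueeze(1) - b.unsqueeze(0)\n",
    "    return torch.exp(-0.5 * (diff / lengthscale) ** 2)\n",
    "\n",
    "\n",
    "def linear_interp_weights(x, grid):\n",
    "    # Dense (n, m) weight matrix with two non-zero entries per row\n",
    "    m = grid.numel()\n",
    "    step = grid[1] - grid[0]\n",
    "    # Index of the grid point to the left of each x, clamped so left + 1 is valid\n",
    "    left = ((x - grid[0]) / step).floor().clamp(0, m - 2).long()\n",
    "    frac = (x - grid[left]) / step\n",
    "    rows = torch.arange(x.numel())\n",
    "    W = torch.zeros(x.numel(), m)\n",
    "    W[rows, left] = 1 - frac\n",
    "    W[rows, left + 1] = frac\n",
    "    return W\n",
    "\n",
    "\n",
    "torch.manual_seed(0)\n",
    "grid = torch.linspace(-1, 1, 64)   # same grid size and bounds as the model defaults\n",
    "x = torch.rand(200) * 2 - 1        # random inputs in [-1, 1)\n",
    "\n",
    "W = linear_interp_weights(x, grid)\n",
    "K_exact = rbf(x, x)\n",
    "K_ski = W @ rbf(grid, grid) @ W.t()\n",
    "\n",
    "# Largest absolute entrywise error of the interpolated kernel\n",
    "print((K_exact - K_ski).abs().max().item())\n",
    "```\n",
    "\n",
    "The point of the construction is that all kernel evaluations happen on the fixed grid, while the data enter only through the sparse weights in $W$."
   ]
  },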
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training the model\n",
    "\n",
    "Once the model has been defined, the training loop looks very similar to other variational models we've seen in the past. We will optimize the variational lower bound as our objective function. In this case, although variational inference in GPyTorch supports stochastic gradient descent, we choose to do batch optimization due to the relatively small toy dataset.\n",
    "\n",
    "For an example of using the `AdditiveGridInducingVariationalGP` model with stochastic gradient descent, see the `dkl_mnist` example."
   ]
  },
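  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, a standard form of the variational lower bound (ELBO) for a GP with inducing values $\\mathbf{u}$ is sketched below. The notation ($q$, $f_i$, $N$) is introduced here for illustration, and GPyTorch may scale the objective differently (for example, by the number of data points):\n",
    "\n",
    "$$\\mathrm{ELBO} = \\sum_{i=1}^{N} \\mathbb{E}_{q(f_i)}\\left[\\log p(y_i \\mid f_i)\\right] - \\mathrm{KL}\\left(q(\\mathbf{u}) \\,\\|\\, p(\\mathbf{u})\\right)$$\n",
    "\n",
    "Here $q(\\mathbf{u})$ is the variational distribution over the inducing values, $q(f_i)$ is the implied marginal over the latent function at the $i$th training input, and $N$ is the number of training points. The training loop minimizes the negative of this quantity."
   ]
  },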
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter 1/50 - Loss: 1.115\n",
      "Iter 2/50 - Loss: 1.108\n",
      "Iter 3/50 - Loss: 1.873\n",
      "Iter 4/50 - Loss: 1.382\n",
      "Iter 5/50 - Loss: 1.366\n",
      "Iter 6/50 - Loss: 1.286\n",
      "Iter 7/50 - Loss: 1.021\n",
      "Iter 8/50 - Loss: 0.837\n",
      "Iter 9/50 - Loss: 0.809\n",
      "Iter 10/50 - Loss: 0.801\n",
      "Iter 11/50 - Loss: 0.706\n",
      "Iter 12/50 - Loss: 0.611\n",
      "Iter 13/50 - Loss: 0.575\n",
      "Iter 14/50 - Loss: 0.550\n",
      "Iter 15/50 - Loss: 0.513\n",
      "Iter 16/50 - Loss: 0.474\n",
      "Iter 17/50 - Loss: 0.450\n",
      "Iter 18/50 - Loss: 0.426\n",
      "Iter 19/50 - Loss: 0.395\n",
      "Iter 20/50 - Loss: 0.369\n",
      "Iter 21/50 - Loss: 0.351\n",
      "Iter 22/50 - Loss: 0.336\n",
      "Iter 23/50 - Loss: 0.318\n",
      "Iter 24/50 - Loss: 0.301\n",
      "Iter 25/50 - Loss: 0.282\n",
      "Iter 26/50 - Loss: 0.264\n",
      "Iter 27/50 - Loss: 0.250\n",
      "Iter 28/50 - Loss: 0.238\n",
      "Iter 29/50 - Loss: 0.231\n",
      "Iter 30/50 - Loss: 0.217\n",
      "Iter 31/50 - Loss: 0.205\n",
      "Iter 32/50 - Loss: 0.194\n",
      "Iter 33/50 - Loss: 0.189\n",
      "Iter 34/50 - Loss: 0.180\n",
      "Iter 35/50 - Loss: 0.173\n",
      "Iter 36/50 - Loss: 0.166\n",
      "Iter 37/50 - Loss: 0.160\n",
      "Iter 38/50 - Loss: 0.155\n",
      "Iter 39/50 - Loss: 0.149\n",
      "Iter 40/50 - Loss: 0.141\n",
      "Iter 41/50 - Loss: 0.138\n",
      "Iter 42/50 - Loss: 0.141\n",
      "Iter 43/50 - Loss: 0.131\n",
      "Iter 44/50 - Loss: 0.135\n",
      "Iter 45/50 - Loss: 0.122\n",
      "Iter 46/50 - Loss: 0.119\n",
      "Iter 47/50 - Loss: 0.123\n",
      "Iter 48/50 - Loss: 0.116\n",
      "Iter 49/50 - Loss: 0.107\n",
      "Iter 50/50 - Loss: 0.108\n",
      "CPU times: user 5.95 s, sys: 640 ms, total: 6.59 s\n",
      "Wall time: 7.13 s\n"
     ]
    }
   ],
   "source": [
    "# Find optimal model hyperparameters\n",
    "model.train()\n",
    "likelihood.train()\n",
    "\n",
    "# Use the adam optimizer\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.1)\n",
    "\n",
    "# \"Loss\" for GPs - the marginal log likelihood\n",
    "# n_data refers to the amount of training data\n",
    "mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.numel())\n",
    "\n",
    "# Training function\n",
    "def train(num_iter=50):\n",
    "    for i in range(num_iter):\n",
    "        optimizer.zero_grad()\n",
    "        output = model(train_x)\n",
    "        loss = -mll(output, train_y)\n",
    "        loss.backward()\n",
    "        print('Iter %d/%d - Loss: %.3f' % (i + 1, num_iter, loss.item()))\n",
    "        optimizer.step()\n",
    "\n",
    "%time train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test the model\n",
    "\n",
    "Next we test the model and plot the decision boundary. Despite the function we are optimizing not having an obvious additive decomposition, the model provides accurate results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.collections.PathCollection at 0x7fbfdc12a908>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQgAAADACAYAAAD4Ov2SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAGoBJREFUeJztnU+MJcddx7/9+r1582fnz87aQxI7EMZBOZBw2C0JJA4gMT5FSIBWsZCQiJDsHLghYSyBtPLNMQckhJDtEC4cIisr8ScgDl4fggJKlPKi+JjIQ8CxE08yu29nd9/8e+8Vh+qe16+n/1VXdXdV9+8j7U7X6+5f1a/qV7/609XVnhACBEEQSfSaTgBBEPZCDoIgiFTIQRAEkQo5CIIgUiEHQRBEKuQgCIJIpd90AkJeeuklet5KEA3xyiuveEm/W+MgAODll1/Ovebg4AA7Ozs1pKZ62qJLW/QAuqnLrVu3Us8ZGWIwxq5nnLvJGNtjjL1oIi6CIOpD20EwxvYAfD3l3HUA4JzfATDKciQEQdiHtoMIKv9+yunnAIyC430Ae7rxEQRRH1U/xdgCcC8SvlZxfARBGMSqSco8eq+9hq133kF/OoXwfWB7G3j4EN7JCcRwCGxsAPfvw5tMINbWgJUV4N49eEJAbG4CALzRCKLXA65eBY6P4R0fS1lXrwJHR/BOTyGWl4H19bms9XVgaQnePenrxNYWMJvBe/BAytreBsZjeOMxxGAAbG0BoxG883NgZQXiyhWZjtlMyur34d2/j6vHx/A/8Qng/Bzeo0dzWWE6lpaAzc0LWZd02tgAPE+mw/OkDqenUla/v6hTNH+mUylreXkxf4SAd3Q0l3VyInUKZY1G8M7O5jodHsITAlu+D39jA979+zJ/rl4FptO5rGvXgEePZF7H8kesrgKrqzJds5nUqdeDNxoBnifz+vwc3sOHMh3b28CDB5d1KlLmYTmFZb69fUmnqx9+iP5gMC/z+/cBIWQ6Qp1CWY8fp+u0tnaRP2JjA/B9KSuuk+/L/Dk6Km/HJyfwHj9eLPOzM/S/9CVAc8K1agcxArAdHG8BOMy6+ODgIFPYUy+9hPWTk4uw8H1402l6uNeDN5vJ4+A3L+Fc3r2XZHmerExlZXkePCFwJXJcJF5bdVqP6ZGrU145RWXl5YfhMr+SUE66+ZOrU0V27A+HOPjUp6BDJQ6CMbbFOR8BeBMAC37eBXAn6768RzKXHtRGMqZQWOXa+Gvw0fPxc1nX5l2fJytPdtY5HZ1UZanopCJbRYekcNl442GTZa5qtyV1Wu73sanZgzDxFOOm/MNuRn5+GwA453eDa/YAjMJwaWjvCoIojoH6ot2D4JzfBnA79tuNyPEbunG0ia64OIGEHh/hHG69i+FZanLUs2kHJsuxJTbhloOwFVsdF0FoQg6iDnJaE9fbmtz0t6Q1XaCNOiVADsIEHTEWIkJHeo3OOYjaqqLJSt8RY0rFFf0j6ay6V1SLHRvId+cchJNQD8MN2lZOBvQhB5GGSmtCVEpt+e9KT6dGyEF0iejy54biJdyCHETLoKpYEy70NmgOQo/OVSYXjNoVLOoVVZmSTjsIQgOVClKTY7KnylpC5yYpFQzNKmOhltsarLKLAjSdXrccRMep21iaNs5CWNTVtw6ag7CHTDMlI05FO2dM9s4sHDZpQQ4iRlMVsSkjJRqlspKyyAba5SBc8OptQSWvLTJ4Y2ToZFWvSJN2OQiTqBi1SWOxyDhM0kIXYT+de4oRp40tkwMIgPK+I7jtIDyvlS1TG3WynXDn6WYiF9aWudsOoiKaNpbWYTgvW5dDFg8ryUG0HBH925TzcWWvxzY6Z03cdhBtLNAqWxNbWyqVdNnyGr6teYlIvnRuHUSGwiaMxZjBtdFxVUmD+VU65op6RbZZjlsOIk4bFyhZbCyl01Nla2tLucUp2SuyDbcdRBwdY4kX
kqYsS822Gap0ehZXrtJY5PTa5SBsNRZ6ItJqROyvDLQj79vlICzA+DJbk72ipkhIRzuqTwY6eW+q3Do3SdlFTA5XSjqbxLva8C6GxY9Mbckxtx1EQiGUyVhaOlwxFU9Sli5zW3pYhqjCjrUdBGPsJmNsjzH2Ysr5Lwd/X9CN6xImC7jKbrChF7+UokyR1YgbNKlTlbRhrsjwOhEtB8EYuw4AnPM7AEZhOMYLjLH3AOzrxBUnUXn6Glb7qbinV1i6obmixJ6MKR0N2HBf8/7nALwVHO8D2ANwN3bN85zz25rxJGPLsMDzjKTlwlhMGkhVeRR07fNM0JISMk88X3Xy2qSsqBhtCfpDjC0A9yLhawnX7GYNQVpBgcIsXNzRNQM94HwDEEFJiwFwvh65VtUC2tArylpNOwAmkfyZrgDT5Xl4si7zVJeiZWnFKk1NdHsQuXDOXwUAxtizjLG9YDiSyMHBQaasp4Qw4hUTiRdKwUIyPdQ53wAOfw34wZ9IxyB6Ar2ZjEj05G+9qQz/6h8AvTEwOA7uXQf8E8A/B+AB5xsC/YeANwOED0yuCAweyGtnS8B0CCw9lOHJmryufyx1Ot8A/DHgTxJk9YHJKjA4CmQNgekAWHoUkTUB+qdznfqPAW8qdZisC/SPAE8As4GsyIMj6e+mK1LHwTiQtQ70Qp0QpONRIMsHzq94WHogMPOBb/2T1Ks3DYrBl3+9IDzzgY9/A/jMX0udZr5Ab2Km3Kyp1JF0nJyc4EFOncpD10GMAGwHx1sADqMng4nJe8EQ4xDAbpawnZ2dzMi8eOtRoCuW1A1OfWwXlZUlOy9exS6i8IHTbeA7/yArnzeTrWHILHb9LCi1b78ZRDeVOome1NWbBWHfgzcT8ETobDx4U+lkRU9e05vN06AmS16bKCtopef3yuOse71ZEI5XaoV0zJYW8+cifyPhH/8u8NHnpazP/TmwdVfKzsVwmatSxo6Xl5exmVOn8tB1EG8CYMHxLoA7AMAY2+KcjwBwzCcnnwHwumZ82ZgcByZdgoK9eoUx5WTdwwe/LfC/X5w7hbCS5EbjX75eYNFwhJ8djtcNM7I8AELp3qS4yqYjj9CRvPuqhye+KfDZl9OvLVzmBUmcZ1JpnGpGa0TGOb8LAIyxPQCjMAzg7cj5LzDGbgJ4L3LePkwu/CkwKx0a/Xf/HvjhH0V7DPF0qJin8qSE4vVNUOXr78DPfhP41r/IoU3oOEphSYU2/d6L9hwE5/yNhN9uZJ2vjQRPLDLOLV6YfK7Qk4YCxvL4F4B3/m6x+2ueKo3Wi8mPh1XuVSHvvizZyecm68C3vgH88i3gif9Uj7EopeWo2HE03NVNa1XVLnR9icVMC4WURXA+vH66AsBLG0rktZg6PYq4jlk6591rsmej00swVH194HyzdzEnAiB78ZximRe5Ng9lO6Z3MSTxQlMxmYtre73scIHMTr03xg+/CPCvIqVexO/SKSKd4UqeLJ2KqaNT/F5zJvz9PwW+95cpJ7N6lADg+4vhAlxcm9Lg6NqxDesgmqUXS37FS69LX3tpTgIZOW9yzB2XVaWzUemN6GByadZlnYpOECuXeRZV2rEmbjmIvMm/pKcH0eN4uGzcOsbheVg0TJ2JxTxZeZUlWvx5DiArnXXqlEdcJxVZse55tBzjlViFPNur044VcctBpGBy3v8inDWWSykUMx16e1qP4hSZOKwqLsP5lSYuZYhR5SBOW1ZXJylTM66CTTqyMrpUOylETn1SmUh0hSZ1MBR3YB8mHySbtOOqmhUnHYQqhTqZdT3H9jxFm3WxR2ETpapo4Mjtcs46A6+ydMJBFCLPawfGYqLrNngE9I7RQMNql8GbIU+n4joP7st3WeIUKvMCNlFLpbZpP4iuUbiAcxZR/fzXgBtfwuV1zsk3FI01Aep9qPDpvwF+5c9KttRZDUzEHlwrkXY5CANe3ggWPaYyR129jzonPBVE11mmBe04N0WdXyhVIAOSrkhe
n1Rz97twdBasMgSglg5Xej3qq2fNRJvxGDOFhSuK9lY69xSj5J4NVaJszmHhzpfR5dxgboxt9l4dXEzznNJlnhZWteMa7d4tB+EKCQV4yaiMNZR1tbhN9UaqpNp0NK5ll4cYjWd+WS45D1c0qXKxk0lKvcEg8TS0LNmqV1r6nRti1EzlVddIPbFl8jCKK06vJiwYCpeFHATInO1CoL4Xv2IxO24Iqm/iFKF1DsKVMvZKl55IObYJW4YQilhkPAtJabAH0joHsUA0Y2vKZPUXxWyt5HViSR7UNlrTiKhmZ9FuBxGllYuX2kJW2VjiPDLwYn+rjax4LCbS0x0HUbPnzSwcIVyet6oASzLD6IOaitbs1PydV7cchKu1qsTKOSKE8qpJ3HIQGRXLajNKSDf5iDpQbFAsKBMLkrCAWw6iYmotnLyneURAfZlktIPaYAtwEXOXV1JaRZm19JXYj9vvODSO0vtolrw5XHEa3HIQVb5RFyHXTlTnFJLiSt2XvCnDUvVYGkuaG8OWdBgk73uhmrjlIBLegrNizFbCcYmL/2zF6sRVRPZGPwuYHELY0NtIwS0HYcvTAEPx2mUWdqWmCZRK1WSl9rxKGjrRuSFGnWh+ezNEfctUK/pEJXBhCXg2mal28bGTAQeh/enY4MvdIwDXOeevqp43SpVdNcOyi28tajLeuiquF4srHnZQJ8VGQStVhmzNa3oOgjF2HQA453cAjMJw0fME4G6PQQWTe2DU2DvJqah17x3RBLpDjOcgewcAsA9gT/G8HjU91cilzBgy/QMdRS8sQJ29k6rkWlqZKvhIk3EsmIPYAnAvEr6meL4dlFl3b2xPyrLXqmJyI9oqN5/R21Gq+K0G37WwpaFLQHsOwiQHBweZ55+azRbKML6U4FI44zFo7r0obi9lZDnUyzSAGwu4lN6Dgob9xCLKuteD/HxKGTs+PTnBg5w6lYeugxgB2A6OtwAcKp5fYGdnJzOy+KRLPNMuffDE81JLPevepHBmusrIqqSXqTo9pj2dVgP16VR5mVcgK+va4fIyNnLqVB66Q4w3AewGx7sA7gAAY2wr67wtNDk9WH5HKZNYkQjDtFGnbKq0Yy0HwTm/CwCMsT0AozAM4O2c80QXHl4QyTg0Sak9B8E5fyPhtxtZ52ujWwN9whUs/lBOHFpJ6Rxdc3pd09ccoumFUl2hknXy1r3u3SSupjsDG3qvFqyDqB0nh+5pr3s7jfMKXCaiUtV2Zkx+xY7IOQdhjKY9vAtPGDsHFUic7jqIhvFUVmJZiYuJz3MALupULW45iKZb/bIkTBbZ/5k3W/La+oxqNW45CFspsz9hJfWvjZVJNaPcWWNghKyd3rs4SalE1mOehjYA8S4dmKROo7XlJTIV8vYOrScVyjS4WU27HYQtXj6J0p6iLmNxtTdSvsw9az1EBrRpbQZVvmKrIrtkQYjI//Ff8+4qh60VQOk965ywyr0NkbD5ciZl7ZyGGDFseSc/TpoDqdxebXEmOi1ZjT0ZpZdGc9KVdb7Kb2pE7+18D8JWChRw+hV1bq6igqPbxmVSz0YtjQ3WOt+DsGXbrxItgnfxH5FdUet0JtkFUllx2WLHCbjlIOqcdHRxm/NL6GxjkkdTvQBXhiu29JL0cMtBVEnNDqEa87Fl3iBOlZWlIUduiwOo+GNSbjsIV7601YreSJw26BR/mqBya4VDVIvsxS0H0abvIday7b1JTHbtbdFpEaPbAFpQyU2o45aDcOXLWUWMI/WSvHubNzzz6KyDUCH7Xi0T6LlVlYritlYNeunMmIukq7Ax2uIQTKbDlonG2C7pWk9uNbxLVT1jWgdhkLqdjYeCtq7TrJlcq2Cy92bJBF8bidqxDZvWNkrGdy8alZXDR78FjD+ZmhDYsy6gLpr6lseinA9/B+iNgSe/bUh8USy2Y+pBVEFOAT38NPB/fwjX3bMCts6rLMb74HPA+18o6H4smISsg3Y7iKZe986RrRezyneWTMpy
pUKYzB9LKGnHJrRtt4OwFWvs1IXhSr2ZVbjdsHjtgknIQTSA2nZzbTA8nbUdNTsxF3xmjbTLQbTUixMq6NVwK/yDRXbsloOwKOOqwwoTdQiTawjasNu4WTozj24Tn/hXYLoG/OTzdcdMH+PI4pNfA5Y/bDoVduFWDyJOfKIoL2wJq+8DT90GMKk7ZjvzwxbWvw889W+G38kog0V2q+0gGGM3GWN7jLEXU85/Ofj7gm5cbWLtR8Cn/xbwzkxKtcewqsO8jr1j4Jf+Clj+sYYQG7c7bHpHKcbYdQDgnN8BMArDMV5gjL0HYF8nrrbRmwBP/yPw2b8Att5pOjVFaKHzEQKb7wC/+FXg594C1n/QdILsQ3cO4jkAbwXH+wD2ANyNXfM85/y2ZjwSi7pemSikc/u7wJX/Ad7/PeBHvw9ghpjbzpAVngon1eL3zkQsHD8fCefKUgjHJ/pU0xW9PzdeUbqZ230NWP4IePKbweyMrZOTDdq9roPYAnAvEr6WcM0uY2wPwHXO+atZwg4ODjIje0qI6r6K3GAhDH8GPPMVYO19wJsAP/114PA3gpNTAH7k4qmQYQF85hX5d7oqx829YwA9YLoieyi9U2DWB2ZDedybANMhIPqAP5b3TFdkxfDHgawrgDcF/BNA+MBkGeidA/4ZMBsAsyV5DhNgtiKv6Y9l9s1WZd32x/KB02RV6uOfyHRMlwH/XKZlOgTEAPCPZTomQ5l2fxyouSIrrX8MiEAnbyrvhS/v750B3jkghjKd+3+MHKcn8MR/AJvfkzptvSvz3Gky7Pj09BQPcupUHpU/xQidAmPsWcbYXjAcSWRnZydTlmfynfsaX84qgieAj/+7rKzXvgN89N+yUnqnsgLMgsqEmaws6AHXvg0Mjub2H2+4hbc44XYpjMV7EQ3n3ZsjSykdhmStfghMVgLH1QP8x/L+2ZoU5I89DI4Ern13fm/riNj1cHkZGzl1Ko9cB5EyubgfzjsA2A5+2wJwmHDvvWCIcQhgVyu1HWFwBDz9z/L4coXwLr65aI97s4Mn/kvN6RnFlgbH8H4QuQ6Cc/5Gxuk3AbDgeBfAHQBgjG1xzkcAOOaTk88AeL18UmFHATRNy/OgUu1anndVoNVn55zfBYBgjmEUhgG8HTn/BcbYTQDvRc43j6axkKkRC9jofGzYMCaph8E5v5F1niCIGujclnOdeBfDEmxsEQk1ml4oRRCEvXRv23uCIIrTuSEG0U5oOFMNNMQgCKJKyEGYoI0tIE0IX6JsKbtsHeQgTNDGypTn9NroFFsG7WpNEESlkIMgytHGXlPboKcY7tGVjrmTetKw6RJuOQhbWy0yLCJOS2zCLQdhK7Y6LoLQhBxEHeS0Jq63Nbnpb0lrukAbdUqAHIQJ8oyFehjtoyNl6pyDqM1vm2whOtLapOJKZYqks+oSc8UinHMQVuJKBSAIRchBpKHSmnS9h1AxlLvNQQ6CIIhUyEGYwKIhRmZrG+np1NoqUw+rGWglpSV0sQJ0UWdLqbIkOu0gjGWsRT2ITJpKZ03xUq8oBm0Ykw4ZS0uhvK4VtxxEvCVyxVgMtaCOaFstmjbgWh42nV63HISruOLIYhhLtSv6q6TThWFl5ycpLepRdK4yqZBhqNraulBRVTFlA513EHFaPgnnLG10ekKkOjdrtKVJygpRydyMa60xFtOQU6wOi/LWiINgjF3POHeTMbbHGHvRRFwLVPlCFfVGsmn55GAurpSTJtoOIviy99dTzl0HAM75HQCjLEfSODSfUVws0LoKoqyT4cbJVgeq7SCCyr+fcvo5AKPgeB/AnlZk8QL0vEoyto0VwDgNzvjbWplKY7GtVT0HsQXgXiR8reL43MdkyxRxoK2rVFXQxslUTayapDw4OMj8N93chIh429nq6sL98bBYWZkf9/sQg0HiuVxZngexvDy/d2kJ6M2zTuSkYyHc60EMh/N7l5cXdVpbK6wTfF9Lp1k0f5aWAN9P1UnE0xUNe96i
TsMhRCR/ZjnpWojL92VaojpFX72PXCsSZC3oNBhA9PuLstLijcuKl/lwWL7M4zrFyzwrP+LpLmDHoZs7+tjHcuvUwcEBsuhnngXAGHsh4ef9YGiRxwjAdnC8BeAw6+KdnZ1MYdN338VPPvgAT66tAYMBMBwCkwlwfAwsLwODASZnZ8DZGbC6Kgv0+BiYzWQYAMZj+fvKivx9PAaWloDhUN57cjKXdXoKnJ/La31fyhJChkNZvi+vn07l+aUlYGlJyjo9ldf2+5icnMi0rq5Kgz8+xuHhIa49/bSUOR5f1mk4nMs6O5un4+RExhdWnqhOYTpCWefn6Tr1ejJdYf4IIe9N0QlxnQJZPx2P8eSTT8prAXk+1Knfl7KK6hSW0/Gx1K2ITrEyn2SVeVynUFag008++ABPbG7O7Wc8ljLC/ImX+Xgsj0OdTk7ktb4v8zbUKSynUFZoe4OBvP/8fMGOEc+fonYc0UkMh7l1Ko9cB8E5f0NVKGNsi3M+AvAmABb8vAugiFNJZ30dYnsbUFH66tXF8Pb2YviawqgnT5Yi08FgLkMlHXE006GLODhIToOOTjpolNPs/HzRvkyWeVyWSdtLkpXTOyiCiacYN+UfdjPy89sAwDm/G1yzB2AUhgmCcIPcHkQenPPbAG7HfrsROVbugRAEYQdWTVISBGEX5CAIgkiFHARBEKloz0GY5NatW00ngSCICJ6g1WMEQaRAQwyCIFIhB0EQRCrkIAiiBVS1J4tzDqKxzWkIAPl57FIZFNDly8HfpPeRrKHKPVmcchBt2ZzGVcPMy2PHyqBIWl9gjL2H9P1OrKDKPVmcchC1bk5TEY4bZl4eO1EGAUXS+jzn/JmCby7bitaeLFatg9DElc1pngPwVnAcGmb8Jbbng3dcbCMvj10pA6BYWneDXut1zvmr9STLLpzqQbSEwobpwji+zXDOXw16D9cCR+EiSnuyxLGqB1Hn5jQ2E7ZWjLFnGWN7FnVx8/LYpTLITGtgi/eCntwh5H4mzmBqTxarHIRVm9NokOPoXDbMxDy2sQwKkKcLx3wO6BkAr9eewoJE92SJDE3fBnCDc36XSUrtyeLUUusgI76CyBidMfZOuP9EULn2Aezaug9FMCnJOOdvBEOIO0EhbnHOw0nL/eD4dQCv27TRTlIeu1YGIQV1uRec7+QchFMOoi2QYRKuQA6CIIhU6CkGQRCpkIMgCCIVchAEQaRCDoIgiFTIQRAEkQo5CIIgUiEHQRBEKv8P71rQNkWS7G8AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 288x216 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Switch the model and likelihood into the evaluation mode\n",
    "model.eval()\n",
    "likelihood.eval()\n",
    "\n",
    "# Start the plot, 4x3in\n",
    "f, ax = plt.subplots(1, 1, figsize=(4, 3))\n",
    "\n",
    "n = 150\n",
    "test_x = torch.zeros(n ** 2, 2)\n",
    "test_x[:, 0].copy_(torch.linspace(-1, 1, n).repeat(n))\n",
    "test_x[:, 1].copy_(torch.linspace(-1, 1, n).unsqueeze(1).repeat(1, n).view(-1))\n",
    "# Cuda variable of test data\n",
    "test_x = test_x.cuda()\n",
    "\n",
    "with torch.no_grad():\n",
    "    predictions = likelihood(model(test_x))\n",
    "\n",
    "# prob<0.5 --> label -1 // prob>0.5 --> label 1\n",
    "pred_labels = predictions.mean.ge(0.5).float().cpu()\n",
    "# Colors = yellow for 1, red for -1\n",
    "color = []\n",
    "for i in range(len(pred_labels)):\n",
    "    if pred_labels[i] == 1:\n",
    "        color.append('y')\n",
    "    else:\n",
    "        color.append('r')\n",
    "        \n",
    "# Plot data a scatter plot\n",
    "ax.scatter(test_x[:, 0].cpu(), test_x[:, 1].cpu(), color=color, s=1)"
   ]
  },
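  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The thresholding step above can be illustrated with a small stand-alone sketch. It assumes a probit link, $p(y = 1 \\mid f) = \\Phi(f)$, for which the predictive probability under a Gaussian latent $f \\sim \\mathcal{N}(\\mu, \\sigma^2)$ has the closed form $\\Phi(\\mu / \\sqrt{1 + \\sigma^2})$. The function names are hypothetical, and the link function is an assumption of this sketch rather than something stated in this notebook.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def normal_cdf(z):\n",
    "    # Standard normal CDF written in terms of the error function\n",
    "    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))\n",
    "\n",
    "\n",
    "def predictive_prob(mean, var):\n",
    "    # E[Phi(f)] for f ~ N(mean, var)\n",
    "    return normal_cdf(mean / math.sqrt(1.0 + var))\n",
    "\n",
    "\n",
    "def predict_label(mean, var):\n",
    "    # Threshold the predictive probability at 0.5\n",
    "    return 1 if predictive_prob(mean, var) >= 0.5 else 0\n",
    "\n",
    "\n",
    "# Latent mean, latent variance, predictive probability, predicted label\n",
    "for mean, var in [(-2.0, 0.5), (0.0, 1.0), (1.5, 4.0)]:\n",
    "    print(mean, var, round(predictive_prob(mean, var), 3), predict_label(mean, var))\n",
    "```\n",
    "\n",
    "Under this assumption, a larger latent variance pulls the probability toward 0.5 but never moves it across 0.5, so the predicted label depends only on the sign of the latent mean."
   ]
  },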
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
